{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# `env_func`\n",
    "\n",
    "源码：`tvm/src/ir/env_func.cc`"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## `EnvFuncNode` 和 `EnvFunc`\n",
    "\n",
    "```c++\n",
    "/*!\n",
    " * \\brief A serializable function backed by TVM's global environment.\n",
    " *\n",
    " * This is a wrapper to enable serializable global PackedFunc.\n",
    " * An EnvFunc is saved by its name in the global registry\n",
    " * under the assumption that the same function is registered during load.\n",
    " * \\sa EnvFunc\n",
    " */\n",
    "class EnvFuncNode : public Object {\n",
    " public:\n",
    "  /*! \\brief Unique name of the global function */\n",
    "  String name;\n",
    "  /*! \\brief The internal packed function */\n",
    "  runtime::PackedFunc func;\n",
    "  /*! \\brief constructor */\n",
    "  EnvFuncNode() {}\n",
    "\n",
    "  void VisitAttrs(AttrVisitor* v) { v->Visit(\"name\", &name); }\n",
    "\n",
    "  bool SEqualReduce(const EnvFuncNode* other, SEqualReducer equal) const {\n",
    "    // name uniquely identifies the env function.\n",
    "    return name == other->name;\n",
    "  }\n",
    "\n",
    "  void SHashReduce(SHashReducer hash_reduce) const {\n",
    "    // Name uniquely identifies the env function.\n",
    "    hash_reduce(name);\n",
    "  }\n",
    "\n",
    "  static constexpr const char* _type_key = \"EnvFunc\";\n",
    "  static constexpr bool _type_has_method_sequal_reduce = true;\n",
    "  static constexpr bool _type_has_method_shash_reduce = true;\n",
    "  TVM_DECLARE_FINAL_OBJECT_INFO(EnvFuncNode, Object);\n",
    "};\n",
    "```\n",
    "\n",
    "`EnvFuncNode` 是继承自 `Object` 的类，它包含字符串类型的成员变量 `name` 和 `runtime::PackedFunc` 类型的成员变量 `func`。这个类的主要目的是作为 TVM 全局环境的包装器，使得函数可以被序列化。在加载时，通过名称在全局注册表中查找相同的函数。此外，它还提供了一些方法来访问和操作这些成员变量。例如，`VisitAttrs` 方法允许访问 `name` 属性，而 `SEqualReduce` 和 `SHashReduce` 方法则用于比较两个 `EnvFuncNode` 对象是否相等以及计算它们的哈希值。在类的声明中，还使用了一些宏来定义一些常量和类型信息。例如，`_type_key` 常量被定义为 `\"EnvFunc\"`，表示这个类的类型键； `_type_has_method_sequal_reduce` 和 `_type_has_method_shash_reduce` 常量被定义为 `true`，表示这个类支持相等性和哈希性的计算方法。最后，`TVM_DECLARE_FINAL_OBJECT_INFO` 宏用于声明这个类的最终对象信息。\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```c++\n",
    "/*!\n",
    " * \\brief Managed reference to EnvFuncNode.\n",
    " * \\sa EnvFuncNode\n",
    " */\n",
    "class EnvFunc : public ObjectRef {\n",
    " public:\n",
    "  EnvFunc() {}\n",
    "  explicit EnvFunc(ObjectPtr<Object> n) : ObjectRef(n) {}\n",
    "  /*! \\return The internal global function pointer */\n",
    "  const EnvFuncNode* operator->() const { return static_cast<const EnvFuncNode*>(get()); }\n",
    "  /*!\n",
    "   * \\brief Invoke the function.\n",
    "   * \\param args The arguments\n",
    "   * \\returns The return value.\n",
    "   */\n",
    "  template <typename... Args>\n",
    "  runtime::TVMRetValue operator()(Args&&... args) const {\n",
    "    const EnvFuncNode* n = operator->();\n",
    "    ICHECK(n != nullptr);\n",
    "    return n->func(std::forward<Args>(args)...);\n",
    "  }\n",
    "  /*!\n",
    "   * \\brief Get a global function based on the name.\n",
    "   * \\param name The name of the global function.\n",
    "   * \\return The created global function.\n",
    "   * \\note The function can be unique\n",
    "   */\n",
    "  TVM_DLL static EnvFunc Get(const String& name);\n",
    "  /*! \\brief specify container node */\n",
    "  using ContainerType = EnvFuncNode;\n",
    "};\n",
    "```\n",
    "`EnvFunc` 是 `EnvFuncNode` 的引用类型，它提供了一种方法来调用内部存储的函数。这个类继承自 `ObjectRef` 类，并提供了以下功能：\n",
    "\n",
    "1. 构造函数：`EnvFunc()` 和 `explicit EnvFunc(ObjectPtr<Object> n) : ObjectRef(n) {}`。这两个构造函数分别用于创建空的 `EnvFunc` 对象和用给定的 `ObjectPtr<Object>` 初始化的 `EnvFunc` 对象。\n",
    "2. 获取内部全局函数指针：`const EnvFuncNode* operator->() const { return static_cast<const EnvFuncNode*>(get()); }`。这个方法返回指向内部全局函数指针的常量指针。\n",
    "3. 调用函数：`template <typename... Args> runtime::TVMRetValue operator()(Args&&... args) const`。这个方法接受一系列参数，并使用这些参数调用内部全局函数。它返回内部全局函数的返回值。\n",
    "4. 根据名称获取全局函数：`TVM_DLL static EnvFunc Get(const String& name);`。这个方法根据给定的名称在全局环境中查找并返回对应的全局函数。\n",
    "5. 指定容器节点：`using ContainerType = EnvFuncNode;`。这行代码声明了 `EnvFunc` 类可以作为 `EnvFuncNode` 类型的容器节点。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "总的来说，`EnvFuncNode` 和 `EnvFunc` 提供了一种机制，可以将 TVM 中的函数封装为可序列化的全局环境，并提供了方便的方法来调用这些函数。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## `TypedEnvFunc`\n",
    "\n",
    "```c++\n",
    "/*!\n",
    " * \\brief Please refer to \\ref TypedEnvFuncAnchor \"TypedEnvFunc<R(Args..)>\"\n",
    " */\n",
    "template <typename FType>\n",
    "class TypedEnvFunc;\n",
    "\n",
    "/*!\n",
    " * \\anchor TypedEnvFuncAnchor\n",
    " * \\brief A typed version of EnvFunc.\n",
    " * It is backed by a GlobalFuncNode internally.\n",
    " *\n",
    " * \\tparam R The return value of the function.\n",
    " * \\tparam Args The argument signature of the function.\n",
    " * \\sa EnvFunc\n",
    " */\n",
    "template <typename R, typename... Args>\n",
    "class TypedEnvFunc<R(Args...)> : public ObjectRef {\n",
    " public:\n",
    "  /*! \\brief short hand for this function type */\n",
    "  using TSelf = TypedEnvFunc<R(Args...)>;\n",
    "  TypedEnvFunc() {}\n",
    "  explicit TypedEnvFunc(ObjectPtr<Object> n) : ObjectRef(n) {}\n",
    "  /*!\n",
    "   * \\brief Assign global function to a TypedEnvFunc\n",
    "   * \\param other Another global function.\n",
    "   * \\return reference to self.\n",
    "   */\n",
    "  TSelf& operator=(const EnvFunc& other) {\n",
    "    ObjectRef::operator=(other);\n",
    "    return *this;\n",
    "  }\n",
    "  /*! \\return The internal global function pointer */\n",
    "  const EnvFuncNode* operator->() const { return static_cast<const EnvFuncNode*>(get()); }\n",
    "  /*!\n",
    "   * \\brief Invoke the function.\n",
    "   * \\param args The arguments\n",
    "   * \\returns The return value.\n",
    "   */\n",
    "  R operator()(Args... args) const {\n",
    "    const EnvFuncNode* n = operator->();\n",
    "    ICHECK(n != nullptr);\n",
    "    return runtime::detail::typed_packed_call_dispatcher<R>::run(n->func,\n",
    "                                                                 std::forward<Args>(args)...);\n",
    "  }\n",
    "  /*! \\brief specify container node */\n",
    "  using ContainerType = EnvFuncNode;\n",
    "};\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "这段代码定义了名为 `TypedEnvFunc` 的模板类，它是对 `EnvFunc` 类的泛型版本。`TypedEnvFunc` 类的主要目的是将全局函数封装为类型安全的函数对象。\n",
    "\n",
    "`TypedEnvFunc` 类有两个模板参数：`R` 和 `Args`。其中，`R` 表示函数的返回值类型，`Args` 表示函数的参数类型。`TypedEnvFunc<R(Args...)>` 表示接受 `Args...` 类型参数并返回 `R` 类型的函数对象。\n",
    "\n",
    "`TypedEnvFunc` 类继承自 `ObjectRef` 类，因此它具有引用计数功能。它提供了一些成员函数，如 `operator=`、`operator->` 和 `operator()`，分别用于赋值、获取内部全局函数指针和调用函数。\n",
    "\n",
    "在 `operator()` 函数中，首先通过 `operator-()` 获取内部全局函数指针，然后使用 `runtime::detail::typed_packed_call_dispatcher<R>::run()` 函数调用全局函数，并将结果返回。\n",
    "\n",
    "此外，`TypedEnvFunc` 类还定义了名为 `ContainerType` 的类型别名，用于指定容器节点类型为 `EnvFuncNode`。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## `EnvFunc` 的实现\n",
    "\n",
    "```c++\n",
    "/*!\n",
    " * \\file env_func.cc\n",
    " */\n",
    "#include <tvm/ir/env_func.h>\n",
    "#include <tvm/runtime/registry.h>\n",
    "#include <tvm/tir/expr.h>\n",
    "\n",
    "namespace tvm {\n",
    "\n",
    "using runtime::PackedFunc;\n",
    "using runtime::TVMArgs;\n",
    "using runtime::TVMRetValue;\n",
    "\n",
    "TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)\n",
    "    .set_dispatch<EnvFuncNode>([](const ObjectRef& node, ReprPrinter* p) {\n",
    "      auto* op = static_cast<const EnvFuncNode*>(node.get());\n",
    "      p->stream << \"EnvFunc(\" << op->name << \")\";\n",
    "    });\n",
    "\n",
    "ObjectPtr<Object> CreateEnvNode(const std::string& name) {\n",
    "  auto* f = runtime::Registry::Get(name);\n",
    "  ICHECK(f != nullptr) << \"Cannot find global function \\'\" << name << '\\'';\n",
    "  ObjectPtr<EnvFuncNode> n = make_object<EnvFuncNode>();\n",
    "  n->func = *f;\n",
    "  n->name = name;\n",
    "  return n;\n",
    "}\n",
    "\n",
    "EnvFunc EnvFunc::Get(const String& name) { return EnvFunc(CreateEnvNode(name)); }\n",
    "\n",
    "TVM_REGISTER_GLOBAL(\"ir.EnvFuncGet\").set_body_typed(EnvFunc::Get);\n",
    "\n",
    "TVM_REGISTER_GLOBAL(\"ir.EnvFuncCall\").set_body([](TVMArgs args, TVMRetValue* rv) {\n",
    "  EnvFunc env = args[0];\n",
    "  ICHECK_GE(args.size(), 1);\n",
    "  env->func.CallPacked(TVMArgs(args.values + 1, args.type_codes + 1, args.size() - 1), rv);\n",
    "});\n",
    "\n",
    "TVM_REGISTER_GLOBAL(\"ir.EnvFuncGetPackedFunc\").set_body_typed([](const EnvFunc& n) {\n",
    "  return n->func;\n",
    "});\n",
    "\n",
    "TVM_REGISTER_NODE_TYPE(EnvFuncNode)\n",
    "    .set_creator(CreateEnvNode)\n",
    "    .set_repr_bytes([](const Object* n) -> std::string {\n",
    "      return static_cast<const EnvFuncNode*>(n)->name;\n",
    "    });\n",
    "\n",
    "}  // namespace tvm\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "这段代码用于处理环境函数。环境函数是一种特殊类型的函数，它们在运行时被调用，而不是在编译时。\n",
    "\n",
    "代码中定义了两个主要的函数：`CreateEnvNode` 和 `EnvFunc::Get`。\n",
    "\n",
    "`CreateEnvNode` 函数接受字符串参数 `name`，这个字符串应该是全局函数的名称。然后，它从注册表中获取这个函数，并创建新的 `EnvFuncNode` 对象，将这个函数和它的名称存储在这个对象中。最后，它返回这个新创建的对象。\n",
    "\n",
    "`EnvFunc::Get` 函数接受字符串参数 `name`，并使用 `CreateEnvNode` 函数来创建对应的 `EnvFuncNode` 对象。然后，它返回这个新创建的对象。\n",
    "\n",
    "此外，代码还注册了几个全局函数，包括 `ir.EnvFuncGet`、`ir.EnvFuncCall` 和 `ir.EnvFuncGetPackedFunc`。这些函数分别用于获取环境函数、调用环境函数和获取环境函数的打包函数。\n",
    "\n",
    "最后，代码还注册了节点类型 `EnvFuncNode`，并设置了它的创建函数和表示函数。创建函数是 `CreateEnvNode`，表示函数是字符串，表示环境函数的名称。"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
